import pandas as pd
import warnings
# Silence UserWarning and FutureWarning messages for the rest of the run.
warnings.simplefilter("ignore", category=UserWarning)
warnings.simplefilter("ignore", category=FutureWarning)

# pandas behaviour switches: opt in to the future no-silent-downcasting rule
# and turn off the chained-assignment (SettingWithCopy) check.
pd.set_option('future.no_silent_downcasting', True)
pd.set_option('mode.chained_assignment', None)

# Results table; the code below reads its Model, Task, F1 and F1_SE columns.
df = pd.read_csv('./data/results_matrix.csv')

def format_f1_table(df, index_col='Model', task_col='Task', value_col='F1', se_col='F1_SE', task_order=None, model_order=None):
    """
    Pivot df to wide form with F1 and F1_SE interleaved and scaled to percentages.

    Duplicate (index, task) rows are averaged. Each task contributes two output
    columns: '<task>' for value_col and '<task>_SE' for se_col.

    Parameters:
    df : pd.DataFrame
        Long-format dataframe containing index_col, task_col, value_col, se_col.
    index_col : str
        Column name for the row index (default 'Model').
    task_col : str
        Column name containing task names (default 'Task').
    value_col : str
        Column name for the main metric (default 'F1').
    se_col : str
        Column name for the standard error metric (default 'F1_SE').
    task_order : list[str] or None
        Desired order of tasks (e.g. ["UKP Stance", "UKP Topic", ...]).
        Tasks not listed are dropped; listed tasks with no data get NaN columns.
        If None, every task found in the data is kept, in the order the pivot
        produces them (pivot_table sorts the labels by default).
    model_order : list[str] or None
        Desired order of models for the index (rows). If provided, the result
        will be reindexed to this order (rows not in the list will be dropped,
        listed models with no data become all-NaN rows).

    Returns:
    pd.DataFrame
        Formatted wide dataframe with columns: [Model, Task1, Task1_SE, Task2, Task2_SE, ...]
        All numeric values multiplied by 100 and rounded to 2 decimals.

    Raises:
    KeyError
        If df lacks one of index_col, task_col, value_col, se_col.
    ValueError
        If two different (metric, task) pairs flatten to the same column name,
        e.g. tasks 'X' and 'X_SE' are both present.
    """
    # Pivot (use pivot_table to guard against duplicate index/task pairs)
    pivot = df[[index_col, task_col, value_col, se_col]].pivot_table(
        index=index_col,
        columns=task_col,
        values=[value_col, se_col],
        aggfunc='mean'   # change if you prefer a different aggregation
    )

    # Snapshot the (metric, task) pairs before the columns are flattened.
    metric_task_pairs = list(pivot.columns)

    # Infer the task order from the real task labels rather than by stripping
    # a '_SE' suffix from flattened names: suffix stripping misreads any task
    # whose own name ends in '_SE'. dict.fromkeys de-duplicates while keeping
    # first-seen order (each task appears once per metric).
    if task_order is None:
        task_order = list(dict.fromkeys(f"{task}" for _, task in metric_task_pairs))

    # Flatten MultiIndex columns: map (metric, task) -> 'task' or 'task_SE'
    def flat_name(metric, task):
        if metric == se_col:
            return f"{task}_SE"
        elif metric == value_col:
            return f"{task}"
        else:
            return f"{metric}_{task}"

    flat_cols = [flat_name(metric, task) for metric, task in metric_task_pairs]

    # The '<task>_SE' naming scheme is ambiguous when one task name equals
    # another task name plus '_SE'; fail loudly instead of mixing columns.
    if len(set(flat_cols)) != len(flat_cols):
        clashes = sorted({name for name in flat_cols if flat_cols.count(name) > 1})
        raise ValueError(
            f"Task names produce ambiguous output columns: {clashes}"
        )

    pivot.columns = flat_cols

    # Reset index to make Model a column
    out = pivot.reset_index()

    # Build interleaved column ordering: task, task_SE, next task, ...
    desired_cols = [index_col]
    for t in task_order:
        desired_cols.append(t)
        desired_cols.append(f"{t}_SE")

    # Select/reorder columns; requested columns with no data are created as NaN.
    out = out.reindex(columns=desired_cols)

    # Reindex rows to the requested model_order (if provided). This will keep the exact order and include NaNs for missing combos.
    if model_order is not None:
        out = out.set_index(index_col).reindex(model_order).reset_index()

    # Multiply numerical columns (all except index_col) by 100 and round to 2 decimals.
    # Convert to numeric where possible to avoid issues with NA types
    numeric_cols = out.columns.drop(index_col)
    out[numeric_cols] = out[numeric_cols].apply(pd.to_numeric, errors='coerce').mul(100).round(2)

    return out

def sort_df(df, colname, order):
    """
    Return df with its rows arranged by a custom ordering of one column.

    df: The dataframe to sort
    colname: The column whose values decide the row order
    order: The values of that column, listed in the order they should appear;
        each value is ranked by order.index, so a value missing from order
        raises ValueError
    """
    def position_of(value):
        # Rank of a single cell = its position in the requested ordering.
        return order.index(value)

    def rank_column(column):
        return column.map(position_of)

    return df.sort_values(by=colname, key=rank_column)

# Raw model identifiers -> display labels used in the tables.
modnames = {
    'base_nli': 'DeBERTa_{Base}',
    'large_nli': 'DeBERTa_{Large}',
    'base_debate': 'DEBATE_{Base}',
    'large_debate': 'DEBATE_{Large}',
    'base_modern': 'DEBATE_{Base(MB)}',
    'large_modern': 'DEBATE_{Large(MB)}',
    'llama': 'Llama 3.1_{8B}',
    'llama70b': 'Llama 3.3_{70B}',
    'sonnet': 'Claude 3.5',
}
df['Model'] = df['Model'].replace(to_replace=modnames)

# Raw task identifiers -> display labels used in the tables.
tasknames = {
    'overall': 'Overall',
    'event extraction': 'Event',
    'hatespeech and toxicity': 'Hatespeech',
    'stance detection': 'Stance',
    'topic classification': 'Topic',
    'PoliStance_Affect': 'Polistance Affect',
    'PoliStance_Affect_QT': 'Polistance Affect Qt',
    'acled_event_entailment': 'ACLED',
    'argument_quality_ranking_entailment': 'Argument Quality Ranking',
    'bill_summary_entailment': 'Bill Summary',
    'dehumanizing_hatespeech_entailment': 'Dehumanizing Hatespeech',
    'dem_rep_party_platform_topics': 'Party Platforms',
    'ibm_claimstance_entailment': 'Claimstance',
    'ibm_claimstance_topic_entailment': 'Claimstance Topic',
    'polistance_issue_tweets': 'Polistance Issue Tweets',
    'scad_event_entailment': 'SCAD',
    'targeted_hatespeech_entailment': 'Targeted Hatespeech',
    'violent_hatespeech_entailment': 'Violent Hatespeech',
}
df['Task'] = df['Task'].replace(to_replace=tasknames)

# Order for the models to appear in
models = [
    'DeBERTa_{Base}',
    'DeBERTa_{Large}',
    'DEBATE_{Base}',
    'DEBATE_{Large}',
    'DEBATE_{Base(MB)}',
    'DEBATE_{Large(MB)}',
    'Llama 3.1_{8B}',
    'Llama 3.3_{70B}',
    'Claude 3.5',
]

###########
## Table 9
###########
# Tasks shown in Table 9, in left-to-right column order.
stance_data = [
    'Polistance Affect',
    'Polistance Affect Qt',
    'Argument Quality Ranking',
    'Claimstance',
    'Polistance Issue Tweets',
]

# Keep only the rows for those tasks, then pivot to the wide percentage table.
stance = df.loc[df['Task'].isin(stance_data)]
stance_formatted = format_f1_table(
    stance, 'Model', 'Task', 'F1', 'F1_SE',
    task_order=stance_data, model_order=models,
)

# Dump the plain-text rendering of the table to disk.
stance_string = stance_formatted.to_string()
output_file = './tables/table_9.txt'
with open(output_file, mode='w') as f:
    f.write(stance_string)

###########
## Table 10
###########
# Tasks shown in Table 10, in left-to-right column order.
topic_data = [
    'Bill Summary',
    'Party Platforms',
    'Claimstance Topic',
]

# Keep only the rows for those tasks, then pivot to the wide percentage table.
topic = df.loc[df['Task'].isin(topic_data)]
topic_formatted = format_f1_table(
    topic, 'Model', 'Task', 'F1', 'F1_SE',
    task_order=topic_data, model_order=models,
)

# Dump the plain-text rendering of the table to disk.
topic_string = topic_formatted.to_string()
output_file = './tables/table_10.txt'
with open(output_file, mode='w') as f:
    f.write(topic_string)

###########
## Table 11
###########
# Tasks shown in Table 11, in left-to-right column order.
hate_data = [
    'Dehumanizing Hatespeech',
    'Targeted Hatespeech',
    'Violent Hatespeech',
]

# Keep only the rows for those tasks, then pivot to the wide percentage table.
hate = df.loc[df['Task'].isin(hate_data)]
hate_formatted = format_f1_table(
    hate, 'Model', 'Task', 'F1', 'F1_SE',
    task_order=hate_data, model_order=models,
)

# Dump the plain-text rendering of the table to disk.
hate_string = hate_formatted.to_string()
output_file = './tables/table_11.txt'
with open(output_file, mode='w') as f:
    f.write(hate_string)

###########
## Table 12
###########
# Tasks shown in Table 12, in left-to-right column order.
event_data = [
    'ACLED',
    'SCAD',
]

# Keep only the rows for those tasks, then pivot to the wide percentage table.
event = df.loc[df['Task'].isin(event_data)]
event_formatted = format_f1_table(
    event, 'Model', 'Task', 'F1', 'F1_SE',
    task_order=event_data, model_order=models,
)

# Dump the plain-text rendering of the table to disk.
event_string = event_formatted.to_string()
output_file = './tables/table_12.txt'
with open(output_file, mode='w') as f:
    f.write(event_string)